{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 决策树\n",
    "\n",
    "- ID3（基于信息增益）\n",
    "- C4.5（基于信息增益比）\n",
    "- CART（gini指数）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### entropy：$H(x) = -\\sum_{i=1}^{n}p_i\\log{p_i}$\n",
    "\n",
    "#### conditional entropy: $H(X|Y)=\\sum{P(X|Y)}\\log{P(X|Y)}$\n",
    "\n",
    "#### information gain : $g(D, A)=H(D)-H(D|A)$\n",
    "\n",
    "#### information gain ratio: $g_R(D, A) = \\frac{g(D,A)}{H(A)}$\n",
    "\n",
    "#### gini index:$Gini(D)=\\sum_{k=1}^{K}p_k\\log{p_k}=1-\\sum_{k=1}^{K}p_k^2$"
   ]
  },
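  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantities above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions rather than the implementation used later in this notebook: the helper names (`entropy`, `gini`, `cond_entropy`, `gain_ratio`) and the toy data are hypothetical, base-2 logarithms are assumed, and `gain_ratio` assumes the feature takes at least two distinct values so that $H(A) > 0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from math import log2\n",
    "\n",
    "def entropy(labels):\n",
    "    \"\"\"Empirical entropy of a list of values (base-2 logarithm).\"\"\"\n",
    "    n = len(labels)\n",
    "    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())\n",
    "\n",
    "def gini(labels):\n",
    "    \"\"\"Gini index 1 - sum_k p_k^2 of a list of class labels.\"\"\"\n",
    "    n = len(labels)\n",
    "    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())\n",
    "\n",
    "def cond_entropy(feature, labels):\n",
    "    \"\"\"Empirical conditional entropy H(D|A) for one feature column.\"\"\"\n",
    "    n = len(labels)\n",
    "    groups = {}\n",
    "    for f, y in zip(feature, labels):\n",
    "        groups.setdefault(f, []).append(y)\n",
    "    return sum(len(g) / n * entropy(g) for g in groups.values())\n",
    "\n",
    "def gain_ratio(feature, labels):\n",
    "    \"\"\"Information gain ratio g_R(D, A) = g(D, A) / H(A).\"\"\"\n",
    "    gain = entropy(labels) - cond_entropy(feature, labels)\n",
    "    return gain / entropy(feature)\n",
    "\n",
    "# Hypothetical toy data: one feature column and the matching class labels\n",
    "toy_feature = ['a', 'a', 'b', 'b']\n",
    "toy_labels = ['yes', 'yes', 'yes', 'no']\n",
    "print(entropy(toy_labels), gini(toy_labels), gain_ratio(toy_feature, toy_labels))"
   ]
  },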
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from collections import Counter\n",
    "import math\n",
    "from math import log\n",
    "\n",
    "import pprint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "书上题目5.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 书上题目5.1\n",
    "def create_data():\n",
    "    datasets = [['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否'],\n",
    "               ]\n",
    "    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n",
    "    # 返回数据集和每个维度的名称\n",
    "    return datasets, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets, labels = create_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.DataFrame(datasets, columns=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>年龄</th>\n",
       "      <th>有工作</th>\n",
       "      <th>有自己的房子</th>\n",
       "      <th>信贷情况</th>\n",
       "      <th>类别</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>一般</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>中年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    年龄 有工作 有自己的房子 信贷情况 类别\n",
       "0   青年   否      否   一般  否\n",
       "1   青年   否      否    好  否\n",
       "2   青年   是      否    好  是\n",
       "3   青年   是      是   一般  是\n",
       "4   青年   否      否   一般  否\n",
       "5   中年   否      否   一般  否\n",
       "6   中年   否      否    好  否\n",
       "7   中年   是      是    好  是\n",
       "8   中年   否      是  非常好  是\n",
       "9   中年   否      是  非常好  是\n",
       "10  老年   否      是  非常好  是\n",
       "11  老年   否      是    好  是\n",
       "12  老年   是      否    好  是\n",
       "13  老年   是      否  非常好  是\n",
       "14  老年   否      否   一般  否"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 熵\n",
    "def calc_ent(datasets):\n",
    "    data_length = len(datasets)\n",
    "    label_count = {}\n",
    "    for i in range(data_length):\n",
    "        label = datasets[i][-1]\n",
    "        if label not in label_count:\n",
    "            label_count[label] = 0\n",
    "        label_count[label] += 1\n",
    "    ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n",
    "    return ent\n",
    "\n",
    "# 经验条件熵\n",
    "def cond_ent(datasets, axis=0):\n",
    "    data_length = len(datasets)\n",
    "    feature_sets = {}\n",
    "    for i in range(data_length):\n",
    "        feature = datasets[i][axis]\n",
    "        if feature not in feature_sets:\n",
    "            feature_sets[feature] = []\n",
    "        feature_sets[feature].append(datasets[i])\n",
    "    cond_ent = sum([(len(p)/data_length)*calc_ent(p) for p in feature_sets.values()])\n",
    "    return cond_ent\n",
    "\n",
    "# 信息增益\n",
    "def info_gain(ent, cond_ent):\n",
    "    return ent - cond_ent\n",
    "\n",
    "def info_gain_train(datasets):\n",
    "    count = len(datasets[0]) - 1\n",
    "    ent = calc_ent(datasets)\n",
    "    best_feature = []\n",
    "    for c in range(count):\n",
    "        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c))\n",
    "        best_feature.append((c, c_info_gain))\n",
    "        print('特征({}) - info_gain - {:.3f}'.format(labels[c], c_info_gain))\n",
    "    # 比较大小\n",
    "    best_ = max(best_feature, key=lambda x: x[-1])\n",
    "    return '特征({})的信息增益最大，选择为根节点特征'.format(labels[best_[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "特征(年龄) - info_gain - 0.083\n",
      "特征(有工作) - info_gain - 0.324\n",
      "特征(有自己的房子) - info_gain - 0.420\n",
      "特征(信贷情况) - info_gain - 0.363\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'特征(有自己的房子)的信息增益最大，选择为根节点特征'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info_gain_train(np.array(datasets))"
   ]
  },
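  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a check on the output above, the gain for 有自己的房子 (owns a house) can be worked by hand. This is a sketch assuming base-2 logarithms, as in `calc_ent`, and writing $A_3$ for this feature (the third column). Of the 15 instances, 9 are labelled 是 and 6 are labelled 否; the 6 instances that own a house are all 是, while the 9 that do not split into 3 是 and 6 否.\n",
    "\n",
    "$$H(D) = -\\frac{9}{15}\\log_2\\frac{9}{15} - \\frac{6}{15}\\log_2\\frac{6}{15} \\approx 0.971$$\n",
    "\n",
    "$$H(D|A_3) = \\frac{6}{15}\\cdot 0 + \\frac{9}{15}\\left(-\\frac{3}{9}\\log_2\\frac{3}{9} - \\frac{6}{9}\\log_2\\frac{6}{9}\\right) \\approx 0.551$$\n",
    "\n",
    "$$g(D, A_3) = H(D) - H(D|A_3) \\approx 0.420$$"
   ]
  },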
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "---\n",
    "\n",
    "利用ID3算法生成决策树，例5.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义节点类 二叉树\n",
    "class Node:\n",
    "    def __init__(self, root=True, label=None, feature_name=None, feature=None):\n",
    "        self.root = root\n",
    "        self.label = label\n",
    "        self.feature_name = feature_name\n",
    "        self.feature = feature\n",
    "        self.tree = {}\n",
    "        self.result = {'label:': self.label, 'feature': self.feature, 'tree': self.tree}\n",
    "\n",
    "    def __repr__(self):\n",
    "        return '{}'.format(self.result)\n",
    "\n",
    "    def add_node(self, val, node):\n",
    "        self.tree[val] = node\n",
    "\n",
    "    def predict(self, features):\n",
    "        if self.root is True:\n",
    "            return self.label\n",
    "        return self.tree[features[self.feature]].predict(features)\n",
    "    \n",
    "class DTree:\n",
    "    def __init__(self, epsilon=0.1):\n",
    "        self.epsilon = epsilon\n",
    "        self._tree = {}\n",
    "\n",
    "    # 熵\n",
    "    @staticmethod\n",
    "    def calc_ent(datasets):\n",
    "        data_length = len(datasets)\n",
    "        label_count = {}\n",
    "        for i in range(data_length):\n",
    "            label = datasets[i][-1]\n",
    "            if label not in label_count:\n",
    "                label_count[label] = 0\n",
    "            label_count[label] += 1\n",
    "        ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n",
    "        return ent\n",
    "\n",
    "    # 经验条件熵\n",
    "    def cond_ent(self, datasets, axis=0):\n",
    "        data_length = len(datasets)\n",
    "        feature_sets = {}\n",
    "        for i in range(data_length):\n",
    "            feature = datasets[i][axis]\n",
    "            if feature not in feature_sets:\n",
    "                feature_sets[feature] = []\n",
    "            feature_sets[feature].append(datasets[i])\n",
    "        cond_ent = sum([(len(p)/data_length)*self.calc_ent(p) for p in feature_sets.values()])\n",
    "        return cond_ent\n",
    "\n",
    "    # 信息增益\n",
    "    @staticmethod\n",
    "    def info_gain(ent, cond_ent):\n",
    "        return ent - cond_ent\n",
    "\n",
    "    def info_gain_train(self, datasets):\n",
    "        count = len(datasets[0]) - 1\n",
    "        ent = self.calc_ent(datasets)\n",
    "        best_feature = []\n",
    "        for c in range(count):\n",
    "            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c))\n",
    "            best_feature.append((c, c_info_gain))\n",
    "        # 比较大小\n",
    "        best_ = max(best_feature, key=lambda x: x[-1])\n",
    "        return best_\n",
    "\n",
    "    def train(self, train_data):\n",
    "        \"\"\"\n",
    "        input:数据集D(DataFrame格式)，特征集A，阈值eta\n",
    "        output:决策树T\n",
    "        \"\"\"\n",
    "        _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[:-1]\n",
    "        # 1,若D中实例属于同一类Ck，则T为单节点树，并将类Ck作为结点的类标记，返回T\n",
    "        if len(y_train.value_counts()) == 1:\n",
    "            return Node(root=True,\n",
    "                        label=y_train.iloc[0])\n",
    "\n",
    "        # 2, 若A为空，则T为单节点树，将D中实例树最大的类Ck作为该节点的类标记，返回T\n",
    "        if len(features) == 0:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征\n",
    "        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))\n",
    "        max_feature_name = features[max_feature]\n",
    "\n",
    "        # 4,Ag的信息增益小于阈值eta,则置T为单节点树，并将D中是实例数最大的类Ck作为该节点的类标记，返回T\n",
    "        if max_info_gain < self.epsilon:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 5,构建Ag子集\n",
    "        node_tree = Node(root=False, feature_name=max_feature_name, feature=max_feature)\n",
    "\n",
    "        feature_list = train_data[max_feature_name].value_counts().index\n",
    "        for f in feature_list:\n",
    "            sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)\n",
    "\n",
    "            # 6, 递归生成树\n",
    "            sub_tree = self.train(sub_train_df)\n",
    "            node_tree.add_node(f, sub_tree)\n",
    "\n",
    "        # pprint.pprint(node_tree.tree)\n",
    "        return node_tree\n",
    "\n",
    "    def fit(self, train_data):\n",
    "        self._tree = self.train(train_data)\n",
    "        return self._tree\n",
    "\n",
    "    def predict(self, X_test):\n",
    "        return self._tree.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets, labels = create_data()\n",
    "data_df = pd.DataFrame(datasets, columns=labels)\n",
    "dt = DTree()\n",
    "tree = dt.fit(data_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label:': None, 'feature': 2, 'tree': {'否': {'label:': None, 'feature': 1, 'tree': {'否': {'label:': '否', 'feature': None, 'tree': {}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'否'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dt.predict(['老年', '否', '否', '一般'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## sklearn.tree.DecisionTreeClassifier\n",
    "\n",
    "### criterion : string, optional (default=”gini”)\n",
    "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data\n",
    "def create_data():\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    # print(data)\n",
    "    return data[:,:2], data[:,-1]\n",
    "\n",
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "from sklearn.tree import export_graphviz\n",
    "import graphviz\n",
    "\n",
    "import os\n",
    "os.environ[\"PATH\"] += os.pathsep + 'E:/DevSoft/Graphviz2.38/bin'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
       "            max_features=None, max_leaf_nodes=None,\n",
       "            min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "            min_samples_leaf=1, min_samples_split=2,\n",
       "            min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n",
       "            splitter='best')"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = DecisionTreeClassifier()\n",
    "clf.fit(X_train, y_train,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  },
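  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the `criterion` parameter in use, the sketch below fits one tree per criterion on the same two-class, two-feature Iris subset. It is an illustration under assumptions, not part of the original experiment: the names `X_demo`, `y_demo`, and `scores` are hypothetical, and `random_state=0` is fixed only so that the sketch is repeatable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Same subset as above: first 100 rows (two classes), first two features\n",
    "iris = load_iris()\n",
    "X_demo, y_demo = iris.data[:100, :2], iris.target[:100]\n",
    "\n",
    "# random_state is an assumption made here for repeatability\n",
    "Xtr, Xte, ytr, yte = train_test_split(X_demo, y_demo, test_size=0.3, random_state=0)\n",
    "\n",
    "# Fit one tree per split criterion and record its test accuracy\n",
    "scores = {}\n",
    "for criterion in ('gini', 'entropy'):\n",
    "    model = DecisionTreeClassifier(criterion=criterion, random_state=0)\n",
    "    model.fit(Xtr, ytr)\n",
    "    scores[criterion] = model.score(Xte, yte)\n",
    "print(scores)"
   ]
  },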
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_pic = export_graphviz(clf, out_file=\"mytree.pdf\")\n",
    "with open('mytree.pdf') as f:\n",
    "    dot_graph = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"569pt\" height=\"477pt\"\r\n",
       " viewBox=\"0.00 0.00 569.00 477.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 473)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-473 565,-473 565,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"391.5,-469 281.5,-469 281.5,-401 391.5,-401 391.5,-469\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"336.5\" y=\"-453.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 5.45</text>\r\n",
       "<text text-anchor=\"middle\" x=\"336.5\" y=\"-438.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.493</text>\r\n",
       "<text text-anchor=\"middle\" x=\"336.5\" y=\"-423.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 70</text>\r\n",
       "<text text-anchor=\"middle\" x=\"336.5\" y=\"-408.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [39, 31]</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"327.5,-365 225.5,-365 225.5,-297 327.5,-297 327.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 2.8</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.18</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 40</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [36, 4]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M317.02,-400.884C311.99,-392.332 306.508,-383.013 301.248,-374.072\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"304.175,-372.144 296.088,-365.299 298.141,-375.693 304.175,-372.144\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"289.814\" y=\"-385.799\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 10 -->\r\n",
       "<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"447.5,-365 345.5,-365 345.5,-297 447.5,-297 447.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.45</text>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.18</text>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 30</text>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 27]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;10 -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>0&#45;&gt;10</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M355.98,-400.884C361.01,-392.332 366.492,-383.013 371.752,-374.072\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"374.859,-375.693 376.912,-365.299 368.825,-372.144 374.859,-375.693\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"383.186\" y=\"-385.799\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"207,-261 112,-261 112,-193 207,-193 207,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"159.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 4.7</text>\r\n",
       "<text text-anchor=\"middle\" x=\"159.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.375</text>\r\n",
       "<text text-anchor=\"middle\" x=\"159.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"159.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 3]</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M238.513,-296.884C227.99,-287.709 216.452,-277.65 205.523,-268.123\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"207.534,-265.233 197.696,-261.299 202.934,-270.509 207.534,-265.233\"/>\r\n",
       "</g>\r\n",
       "<!-- 5 -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"327.5,-261 225.5,-261 225.5,-193 327.5,-193 327.5,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.05</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.054</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 36</text>\r\n",
       "<text text-anchor=\"middle\" x=\"276.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [35, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;5 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>1&#45;&gt;5</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M276.5,-296.884C276.5,-288.778 276.5,-279.982 276.5,-271.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"280,-271.299 276.5,-261.299 273,-271.299 280,-271.299\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"95,-149.5 0,-149.5 0,-96.5 95,-96.5 95,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"47.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"47.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"47.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M123.137,-192.884C110.336,-181.226 95.9681,-168.141 83.2209,-156.532\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"85.272,-153.666 75.5219,-149.52 80.5587,-158.841 85.272,-153.666\"/>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"208,-149.5 113,-149.5 113,-96.5 208,-96.5 208,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"160.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"160.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"160.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 3]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>2&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M159.825,-192.884C159.928,-182.326 160.043,-170.597 160.148,-159.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"163.652,-159.554 160.25,-149.52 156.652,-159.485 163.652,-159.554\"/>\r\n",
       "</g>\r\n",
       "<!-- 6 -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"322,-157 227,-157 227,-89 322,-89 322,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"274.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 5.15</text>\r\n",
       "<text text-anchor=\"middle\" x=\"274.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.32</text>\r\n",
       "<text text-anchor=\"middle\" x=\"274.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"274.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [4, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;6 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>5&#45;&gt;6</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M275.851,-192.884C275.692,-184.778 275.519,-175.982 275.352,-167.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"278.848,-167.229 275.153,-157.299 271.85,-167.366 278.848,-167.229\"/>\r\n",
       "</g>\r\n",
       "<!-- 9 -->\r\n",
       "<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"442.5,-149.5 340.5,-149.5 340.5,-96.5 442.5,-96.5 442.5,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"391.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"391.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 31</text>\r\n",
       "<text text-anchor=\"middle\" x=\"391.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [31, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;9 -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>5&#45;&gt;9</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M313.837,-192.884C327.105,-181.116 342.012,-167.894 355.192,-156.203\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"357.569,-158.774 362.728,-149.52 352.924,-153.537 357.569,-158.774\"/>\r\n",
       "</g>\r\n",
       "<!-- 7 -->\r\n",
       "<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"265,-53 170,-53 170,-0 265,-0 265,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"217.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"217.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"217.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [4, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 6&#45;&gt;7 -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>6&#45;&gt;7</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M254.564,-88.9485C249.267,-80.1664 243.535,-70.6629 238.198,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"241.186,-59.9919 233.024,-53.2367 235.192,-63.6074 241.186,-59.9919\"/>\r\n",
       "</g>\r\n",
       "<!-- 8 -->\r\n",
       "<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"378,-53 283,-53 283,-0 378,-0 378,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"330.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"330.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"330.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 6&#45;&gt;8 -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>6&#45;&gt;8</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M294.086,-88.9485C299.236,-80.2579 304.805,-70.8608 310.001,-62.0917\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"313.162,-63.624 315.249,-53.2367 307.14,-60.0553 313.162,-63.624\"/>\r\n",
       "</g>\r\n",
       "<!-- 11 -->\r\n",
       "<g id=\"node12\" class=\"node\"><title>11</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"447.5,-253.5 345.5,-253.5 345.5,-200.5 447.5,-200.5 447.5,-253.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-238.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-223.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 27</text>\r\n",
       "<text text-anchor=\"middle\" x=\"396.5\" y=\"-208.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 27]</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;11 -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>10&#45;&gt;11</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M396.5,-296.884C396.5,-286.326 396.5,-274.597 396.5,-263.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"400,-263.52 396.5,-253.52 393,-263.52 400,-263.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 12 -->\r\n",
       "<g id=\"node13\" class=\"node\"><title>12</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"561,-253.5 466,-253.5 466,-200.5 561,-200.5 561,-253.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"513.5\" y=\"-238.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"513.5\" y=\"-223.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"513.5\" y=\"-208.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;12 -->\r\n",
       "<g id=\"edge12\" class=\"edge\"><title>10&#45;&gt;12</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M434.487,-296.884C447.985,-285.116 463.151,-271.894 476.561,-260.203\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"478.989,-262.73 484.227,-253.52 474.389,-257.453 478.989,-262.73\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x189c94e0>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graphviz.Source(dot_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
